#!/usr/bin/env python
import rospy
from std_msgs.msg import String
from sensor_msgs.msg import Image, CompressedImage
from cv_bridge import CvBridge, CvBridgeError
import message_filters

import cv2
import sm
import aslam_cv as acv
import aslam_cv_backend as acvb
import aslam_cameras_april as acv_april
import kalibr_common as kc

import os
import time
import numpy as np
import pylab as pl
import argparse
import sys
import getopt
import igraph

# make numpy print prettier
np.set_printoptions(suppress=True)

class CalibrationTargetDetector(object):
    """Grid detector for the calibration target described by a target config.

    The target type reported by the config selects which grid object is
    created (checkerboard, circlegrid or aprilgrid). The resulting
    ``acv.GridDetector`` is stored in ``self.detector``.
    """

    def __init__(self, camera, targetConfig):
        """Create the target grid and the detector operating on it.

        Args:
            camera: object whose ``geometry`` attribute is handed to the detector.
            targetConfig: object providing ``getTargetParams()`` (dict-like
                lookup of the target dimensions) and ``getTargetType()``.

        Raises:
            RuntimeError: if the target type is none of 'checkerboard',
                'circlegrid' or 'aprilgrid'.
        """
        params = targetConfig.getTargetParams()
        kind = targetConfig.getTargetType()

        #the grid is built first so an unknown target type fails before any
        #detector object is created
        targetGrid = self._makeGrid(kind, params)

        #detector with corner outlier filtering switched on
        detectorOptions = acv.GridDetectorOptions()
        detectorOptions.filterCornerOutliers = True
        self.detector = acv.GridDetector(camera.geometry, targetGrid, detectorOptions)

    @staticmethod
    def _makeGrid(kind, params):
        #map the target type string onto the matching grid constructor
        if kind == 'checkerboard':
            return acv.GridCalibrationTargetCheckerboard(params['targetRows'],
                                                         params['targetCols'],
                                                         params['rowSpacingMeters'],
                                                         params['colSpacingMeters'])

        if kind == 'circlegrid':
            circleOptions = acv.CirclegridOptions()
            circleOptions.useAsymmetricCirclegrid = params['asymmetricGrid']
            return acv.GridCalibrationTargetCirclegrid(params['targetRows'],
                                                       params['targetCols'],
                                                       params['spacingMeters'],
                                                       circleOptions)

        if kind == 'aprilgrid':
            return acv_april.GridCalibrationTargetAprilgrid(params['tagRows'],
                                                            params['tagCols'],
                                                            params['tagSize'],
                                                            params['tagSpacing'])

        raise RuntimeError( "Unknown calibration target." )
    
class CameraChainValidator(object):
    """Live validator for a chain of calibrated cameras.

    Creates one MonoCameraValidator per camera of the chain configuration,
    connects cameras that are configured as overlapping in a graph
    (``self.G``) and shows, for every connected pair, a rectified side-by-side
    view annotated with cross-camera reprojection error statistics.
    """

    def __init__(self, chainConfig, targetParams):
        """Create the per-camera validators, rectification maps, windows and the image synchronizer.

        Args:
            chainConfig: camera chain configuration; the code calls
                numCameras(), getCameraParameters(idx), getCamOverlaps(idx)
                and getExtrinsicsLastCamToHere(idx) on it.
            targetParams: calibration target configuration, handed unchanged
                to every MonoCameraValidator.
        """
        self.chainConfig = chainConfig
        self.numCameras = chainConfig.numCameras()
        self.bridge = CvBridge()

        #initialize the cameras in the chain
        self.G = igraph.Graph(self.numCameras)
        self.monovalidators = []
        for cidx in range(0, self.numCameras):
            camConfig = chainConfig.getCameraParameters(cidx)

            #create a mono instance for each cam (detection and mono view)
            monovalidator = MonoCameraValidator(camConfig, targetParams)
            cv2.namedWindow(monovalidator.windowName, cv2.WINDOW_NORMAL)
            cv2.resizeWindow(monovalidator.windowName, (640, 480))
            self.monovalidators.append(monovalidator)

            #add edges to overlap graph
            for overlap in chainConfig.getCamOverlaps(cidx):
                #error=False lets get_eid report a missing edge as -1, so no
                #catch-all exception handler is needed to detect "no edge yet"
                if self.G.get_eid(cidx, overlap, error=False) == -1:
                    self.G.add_edges([(cidx, overlap)])

        #prepare the rectification maps (stored per edge, keyed by camera index)
        for edge in self.G.es:
            cidx_src = edge.source
            cidx_dest = edge.target
            map_src, map_dest, R_src, R_dest, A_src, A_dest = \
                self.prepareStereoRectificationMaps(cidx_src, cidx_dest)
            edge["rect_map"] = {cidx_src: map_src, cidx_dest: map_dest}
            edge["R"] = {cidx_src: R_src, cidx_dest: R_dest}
            edge["A"] = {cidx_src: A_src, cidx_dest: A_dest}

        # generate the windows
        for edge in self.G.es:
            windowName = self._pairWindowName(edge.source, edge.target)
            cv2.namedWindow(windowName, cv2.WINDOW_NORMAL)
            cv2.resizeWindow(windowName, (2*640, 480))

        #register the callback for the synchronized images
        sync_sub = message_filters.ApproximateTimeSynchronizer([val.image_sub for val in self.monovalidators], 10, 0.02)
        sync_sub.registerCallback(self.synchronizedCallback)

        #initialize message throttler
        self.timeLast = 0

    @staticmethod
    def _pairWindowName(cidx_src, cidx_dest):
        #title of the rectified view window of a camera pair
        return "Rectified view (cam{0} and cam{1})".format(cidx_src, cidx_dest)

    @staticmethod
    def _hasFrame(validator):
        #true once the validator carries an observation and an undistorted image
        return hasattr(validator, "obs") and hasattr(validator, "undist_image")

    def synchronizedCallback(self, *cam_msgs):
        """Process one set of synchronized images (one message per camera).

        Messages are matched to ``self.monovalidators`` by position. Calls
        arriving less than 1/rate seconds after the last processed set are
        dropped. For every camera the target is detected, the image is
        undistorted and the mono view is drawn; afterwards the rectified view
        of every camera pair in the overlap graph is drawn.
        """
        #throttle image processing
        rate = 2 #Hz
        timeNow = time.time()
        if (timeNow-self.timeLast < 1.0/rate) and self.timeLast!=0:
            return
        self.timeLast = timeNow

        #process the images of all cameras
        self.observations=[]
        for cam_nr, msg in enumerate(cam_msgs):

            #convert image to gray single channel numpy array
            try:
                if msg._type == 'sensor_msgs/CompressedImage':
                    np_image = np.array(self.bridge.compressed_imgmsg_to_cv2(msg))
                    if len(np_image.shape) > 2 and np_image.shape[2] == 3:
                        np_image = cv2.cvtColor(np_image, cv2.COLOR_BGR2GRAY)
                elif msg.encoding == "rgb8" or msg.encoding == "bgra8":
                    np_image = np.array(self.bridge.imgmsg_to_cv2(msg, "mono8"))
                else:
                    #every other encoding is passed through without conversion
                    np_image = np.array(self.bridge.imgmsg_to_cv2(msg))
            except CvBridgeError as e:
                #this camera keeps whatever it stored for the previous frame
                print("Could not convert ros message to opencv image: ", e)
                continue

            #get the corresponding monovalidator instance
            validator = self.monovalidators[cam_nr]

            #detect targets for all cams
            timestamp = acv.Time(msg.header.stamp.secs, msg.header.stamp.nsecs)
            success, observation = validator.target.detector.findTarget(timestamp, np_image)
            observation.clearImage()
            validator.obs = observation

            #undistort the image
            if isinstance(validator.camera.geometry, acv.DistortedOmniCameraGeometry):
                validator.undist_image = validator.undistorter.undistortImageToPinhole(np_image)
            else:
                validator.undist_image = validator.undistorter.undistortImage(np_image)

            #generate a mono view for each cam
            validator.generateMonoview(np_image, observation, success)

        #generate all rectification views
        for edge in self.G.es:
            self.generatePairView(edge.source, edge.target)

        cv2.waitKey(1)

    #returns transformation T_to_from
    def getTransformationCamFromTo(self, cidx_from, cidx_to):
        """Return the transformation T_to_from between two cameras of the chain.

        The baselines returned by chainConfig.getExtrinsicsLastCamToHere() are
        chained from the lower to the higher camera index; the product is
        inverted when ``cidx_from`` is the higher index. Equal indices yield a
        default-constructed sm.Transformation.
        """
        #build pose chain (target->cam0->baselines->camN)
        lowid = min((cidx_from, cidx_to))
        highid = max((cidx_from, cidx_to))

        T_high_low = sm.Transformation()
        for cidx in range(lowid, highid):
            baseline_HL = self.chainConfig.getExtrinsicsLastCamToHere(cidx+1)
            T_high_low = baseline_HL * T_high_low

        if cidx_from<cidx_to:
            T_BA = T_high_low
        else:
            T_BA = T_high_low.inverse()

        return T_BA

    def rectifyImages(self, imageA, mapA, imageB, mapB):
        """Remap both images with their (map_x, map_y) pair and return them stacked side by side."""
        #rectify images
        rect_image_A = cv2.remap(imageA,
                                 mapA[0],
                                 mapA[1],
                                 cv2.INTER_LINEAR)

        rect_image_B = cv2.remap(imageB,
                                 mapB[0],
                                 mapB[1],
                                 cv2.INTER_LINEAR)

        #combine the images (A left, B right)
        return np.hstack( (rect_image_A, rect_image_B) )

    @staticmethod
    def _pairReprojectionErrors(camA, camB, T_BA):
        #reprojection errors of the corners seen by both cameras, obtained by
        #mapping each camera's target pose estimate into the other camera
        #through the baseline T_BA; returns a (meas, 2) array or None if the
        #two observations share no corner
        keypoints_A = camA.obs.getCornersImageFrame()
        keypoints_A_id = camA.obs.getCornersIdx()
        keypoints_B = camB.obs.getCornersImageFrame()
        keypoints_B_id = camB.obs.getCornersIdx()
        targetPoints = camA.obs.getCornersTargetFrame()

        #get the common corners
        common_ids = set(keypoints_A_id) & set(keypoints_B_id)
        if not common_ids:
            return None

        #the poses do not depend on the corner, so compute them once
        T_AW = camA.obs.T_t_c().inverse()
        T_BW_base = T_BA*T_AW
        T_BW = camB.obs.T_t_c().inverse()
        T_AW_base = T_BA.inverse()*T_BW

        errs = []
        for corner_id in common_ids:
            in_A = np.where(keypoints_A_id==corner_id)
            keypoint_A = keypoints_A[in_A]
            keypoint_B = keypoints_B[np.where(keypoints_B_id==corner_id)]
            targetPoint = np.matrix(targetPoints[in_A]).reshape(3,1)

            #target point -> cam B via the pose seen by cam A (and vice versa)
            reprojB = np.matrix(T_BW_base.C()) * targetPoint + np.matrix(T_BW_base.t()).reshape(3,1)
            reprojErr_B = camB.camera.geometry.euclideanToKeypoint(reprojB) - keypoint_B

            reprojA = np.matrix(T_AW_base.C()) * targetPoint + np.matrix(T_AW_base.t()).reshape(3,1)
            reprojErr_A = camA.camera.geometry.euclideanToKeypoint(reprojA) - keypoint_A

            #reprojection errors in original camera geometry
            errs.append(reprojErr_A.flatten())
            errs.append(reprojErr_B.flatten())

        return np.hstack(errs).reshape(-1, 2) # (meas, 2)

    @staticmethod
    def _errorStatistics(errors_xy, errors_l2):
        #table of (label, value) rows; rows with an empty label are spacers
        return [ ( "mean_x:  ", np.mean(errors_xy[:,0]) ),
                 ( "std_x:   ", np.std(errors_xy[:,0]) ),
                 ( "max_x:   ", np.max(errors_xy[:,0]) ),
                 ( "min_x:   ", np.min(errors_xy[:,0]) ),
                 ( "", 0),
                 ( "mean_y:  ", np.mean(errors_xy[:,1]) ),
                 ( "std_y:   ", np.std(errors_xy[:,1]) ),
                 ( "max_y:   ", np.max(errors_xy[:,1]) ),
                 ( "min_y:   ", np.min(errors_xy[:,1]) ),
                 ( "", 0),
                 ( "mean_L2: ", np.mean(errors_l2) ),
                 ( "std_L2:  ", np.std(errors_l2) ),
                 ( "max_L2:  ", np.max(errors_l2) ),
                 ( "min_L2:  ", np.min(errors_l2) )      ]

    def generatePairView(self, camAnr, camBnr):
        """Draw the rectified side-by-side view of a camera pair.

        The view shows both undistorted images after rectification, ten
        horizontal guide lines and either the cross-camera reprojection error
        statistics or "Detection failed..." when the two observations share no
        target corner. Returns without drawing while one of the cameras has
        not stored an observation and undistorted image yet.
        """
        #get the mono validators for each cam
        camA = self.monovalidators[camAnr]
        camB = self.monovalidators[camBnr]

        #a camera whose image could not be converted so far has nothing to show
        if not (self._hasFrame(camA) and self._hasFrame(camB)):
            return

        #get baseline between camA & camB
        T_BA = self.getTransformationCamFromTo(camAnr, camBnr)

        #project points from camera A to cam B to compare with detection
        #and vice versa
        reproj_errs = self._pairReprojectionErrors(camA, camB, T_BA)

        #rectify the undistorted images
        edge = self.G.es[self.G.get_eid(camAnr, camBnr)]
        np_image_rect = self.rectifyImages(camA.undist_image,
                                           edge["rect_map"][camAnr],
                                           camB.undist_image,
                                           edge["rect_map"][camBnr])

        #draw some epilines
        np_image_rect = cv2.cvtColor(np_image_rect, cv2.COLOR_GRAY2BGR)
        n=10
        for i in range(0,n):
            y = int(np_image_rect.shape[0]*i/n)
            #the end point lies beyond the right border; the line is clipped there
            cv2.line(np_image_rect, (0,y), (2*np_image_rect.shape[1],y),(0,255,0))

        if reproj_errs is not None:
            #draw error statistics
            reproj_L2_errs = np.sum(np.abs(reproj_errs)**2, axis=-1)**(1./2)
            outputList = self._errorStatistics(reproj_errs, reproj_L2_errs)

            #print the text
            y = 20; x = 20
            for err_txt, err_val in outputList:
                fontScale = 0.75
                y += int(42*fontScale)

                #spacer rows only advance the text cursor
                if err_txt == "":
                    continue

                cv2.putText(np_image_rect, err_txt, (x,y), cv2.FONT_HERSHEY_SIMPLEX, fontScale=fontScale, color=(0, 0, 255), thickness=2)
                cv2.putText(np_image_rect, "{0: .4f}".format(err_val), (x+100,y), cv2.FONT_HERSHEY_SIMPLEX, fontScale=fontScale, color=(0, 0, 255), thickness=2)
        else:
            cv2.putText(np_image_rect, "Detection failed...", (int(np_image_rect.shape[0]/2.0),int(np_image_rect.shape[1]/5.0)), cv2.FONT_HERSHEY_SIMPLEX, fontScale=2, color=(0, 0, 255), thickness=3)

        cv2.imshow(self._pairWindowName(camAnr, camBnr), np_image_rect)

    def prepareStereoRectificationMaps(self, camAnr, camBnr):
        """Compute the rectification maps for a camera pair.

        Builds a common rectified orientation whose x axis points along the
        baseline between the two undistorted cameras and uses the mean of both
        intrinsic matrices as the new camera matrix.

        Returns:
            A 6-tuple ``(mapA, mapB, Ra, Rb, A, A)`` where ``mapA``/``mapB``
            are the ``(map_x, map_y)`` pairs from cv2.initUndistortRectifyMap,
            ``Ra``/``Rb`` the rectifying rotations and ``A`` the shared new
            camera matrix. The image size of camera A is used for both maps.
        """
        #get the camera parameters for the undistorted cameras
        projA = self.monovalidators[camAnr].undist_camera.projection()
        projB = self.monovalidators[camBnr].undist_camera.projection()
        camIdealA = projA.getParameters().flatten()
        camIdealB = projB.getParameters().flatten()
        camIdealA = np.array([[camIdealA[0],0,camIdealA[2]], [0,camIdealA[1],camIdealA[3]], [0,0,1]])
        camIdealB = np.array([[camIdealB[0],0,camIdealB[2]], [0,camIdealB[1],camIdealB[3]], [0,0,1]])
        imageSize = (projA.ru(), projA.rv())

        #get the baseline between the cams
        baseline_BA = self.getTransformationCamFromTo(camAnr, camBnr)

        ##
        #A.Fusiello, E. Trucco, A. Verri: A compact algorithm for rectification of stereo pairs, 1999
        ##
        Poa = np.matrix(camIdealA) * np.hstack( (np.matrix(np.eye(3)), np.matrix(np.zeros((3,1)))) ) #use camA coords as world frame...
        Pob = np.matrix(camIdealB) * np.hstack( (np.matrix(baseline_BA.C()), np.matrix(baseline_BA.t()).T) )

        #optical centers (in camA's coord sys)
        c1 = -np.linalg.inv(Poa[:,0:3]) * Poa[:,3]
        c2 = -np.linalg.inv(Pob[:,0:3]) * Pob[:,3]

        #get "mean" rotation between cams
        old_z_mean = (baseline_BA.C()[2,:].flatten()+sm.Transformation().T()[2,0:3])/2.0
        v1 = c1-c2 #newx-axis = direction of baseline
        v2 = np.cross(np.matrix(old_z_mean).flatten(), v1.flatten()).T #new y axis orthogonal to new x and mean old z
        v3 = np.cross(v1.flatten(), v2.flatten()).T #orthogonal to baseline and new y

        #normalize
        v1 = v1/np.linalg.norm(v1)
        v2 = v2/np.linalg.norm(v2)
        v3 = v3/np.linalg.norm(v3)

        #create rotation matrix (rows are the new axes)
        R = np.hstack((np.hstack((v1,v2)),v3)).T

        #new intrinsic parameters
        A = (camIdealA + camIdealB)/2.0

        #only the rectifying rotations are needed for the maps below
        Ra = R #camA=world, then to rectified coords
        Rb = R * baseline_BA.inverse().C() #to world then to rectified coords

        #create the rectification maps (inputs are already undistorted, hence zero distortion)
        rect_map_x_a, rect_map_y_a = cv2.initUndistortRectifyMap(camIdealA,
                                                                 np.zeros((4,1)),
                                                                 Ra,
                                                                 A,
                                                                 imageSize,
                                                                 cv2.CV_16SC2)

        rect_map_x_b, rect_map_y_b = cv2.initUndistortRectifyMap(camIdealB,
                                                                 np.zeros((4,1)),
                                                                 Rb,
                                                                 A,
                                                                 imageSize,
                                                                 cv2.CV_16SC2)

        return (rect_map_x_a, rect_map_y_a), (rect_map_x_b, rect_map_y_b), Ra, Rb, A, A
    
        
class MonoCameraValidator(object):
    """Validator for a single camera: target detection, undistortion and a mono view.

    Holds the camera model, the target detector, the image subscriber for the
    camera topic and the image undistorter, and draws a window with the
    reprojection error statistics of every processed frame.
    """

    def __init__(self, camConfig, targetParams):
        """Create camera model, target detector, subscriber and undistorter.

        Args:
            camConfig: camera configuration; the code calls getRosTopic() and
                printDetails() on it and passes it to
                kc.ConfigReader.AslamCamera.fromParameters().
            targetParams: calibration target configuration handed to
                CalibrationTargetDetector.
        """
        print("initializing camera geometry")
        self.camera = kc.ConfigReader.AslamCamera.fromParameters(camConfig)
        #use the configuration passed in by the caller, not a module-level global
        self.target = CalibrationTargetDetector(self.camera, targetParams)

        #print details
        print("Camera {0}:".format(camConfig.getRosTopic()))
        camConfig.printDetails()

        self.topic = camConfig.getRosTopic()
        self.windowName = "Camera: {0}".format(self.topic)

        #register the cam topic to the message synchronizer
        #(the message type is chosen from the topic name)
        if "compressed" in self.topic:
            self.image_sub = message_filters.Subscriber(self.topic, CompressedImage)
        else:
            self.image_sub = message_filters.Subscriber(self.topic, Image)

        #create image undistorter
        alpha = 1.0
        scale = 1.0
        self.undistorter = self.camera.undistorterType(self.camera.geometry, cv2.INTER_LINEAR, alpha, scale)

        if isinstance(self.camera.geometry, acv.DistortedOmniCameraGeometry):
            #convert omni image to pinhole image as well
            self.undist_camera = self.undistorter.getIdealPinholeGeometry()
        else:
            self.undist_camera = self.undistorter.getIdealGeometry()

        #storage for reproj errors (one entry is appended per valid frame)
        self.cornersImage = list()
        self.reprojectionErrors = list()

    @staticmethod
    def _errorStatistics(errors_xy, errors_l2):
        #table of (label, value) rows; rows with an empty label are spacers
        return [ ( "mean_x:  ", np.mean(errors_xy[:,0]) ),
                 ( "std_x:   ", np.std(errors_xy[:,0]) ),
                 ( "max_x:   ", np.max(errors_xy[:,0]) ),
                 ( "min_x:   ", np.min(errors_xy[:,0]) ),
                 ( "", 0),
                 ( "mean_y:  ", np.mean(errors_xy[:,1]) ),
                 ( "std_y:   ", np.std(errors_xy[:,1]) ),
                 ( "max_y:   ", np.max(errors_xy[:,1]) ),
                 ( "min_y:   ", np.min(errors_xy[:,1]) ),
                 ( "", 0),
                 ( "mean_L2: ", np.mean(errors_l2) ),
                 ( "std_L2:  ", np.std(errors_l2) ),
                 ( "max_L2:  ", np.max(errors_l2) ),
                 ( "min_L2:  ", np.min(errors_l2) )      ]

    def generateMonoview(self, np_image, observation, obs_valid):
        """Draw the mono view of one frame into this camera's window.

        Args:
            np_image: gray image; it is converted with COLOR_GRAY2BGR before drawing.
            observation: target observation belonging to ``np_image``.
            obs_valid: True if the target was detected. Then the reprojection
                error statistics and the reprojected corners are drawn and the
                errors are stored; otherwise "Detection failed..." is shown.
        """
        np_image = cv2.cvtColor(np_image, cv2.COLOR_GRAY2BGR)

        if obs_valid:
            #calculate the reprojection error statistics
            cornersImage = observation.getCornersImageFrame()
            cornersReproj = observation.getCornerReprojection(self.camera.geometry) # (meas, 2)

            reprojectionErrors2 = cornersImage-cornersReproj
            reprojectionErrors = np.sum(np.abs(reprojectionErrors2)**2, axis=-1)**(1./2)

            #save the errors for reprojection error map plotting
            self.cornersImage.append(cornersImage)
            self.reprojectionErrors.append(reprojectionErrors)

            outputList = self._errorStatistics(reprojectionErrors2, reprojectionErrors)

            #print the text
            y = 20; x = 20
            for err_txt, err_val in outputList:
                fontScale = 0.75
                y += int(42*fontScale)

                #spacer rows only advance the text cursor
                if err_txt == "":
                    continue

                cv2.putText(np_image, err_txt, (x,y), cv2.FONT_HERSHEY_SIMPLEX, fontScale=fontScale, color=(0, 0, 255), thickness=2)
                cv2.putText(np_image, "{0: .4f}".format(err_val), (x+100,y), cv2.FONT_HERSHEY_SIMPLEX, fontScale=fontScale, color=(0, 0, 255), thickness=2)

            #draw reprojected corners
            for px, py in zip(cornersReproj[:,0],cornersReproj[:,1]):
                #convert pixel to fixed point (opencv subpixel rendering...)
                shift = 4; radius = 0.5; thickness = 1
                px_fix =  int(px * 2**shift)
                py_fix =  int(py * 2**shift)
                radius_fix = int(radius * 2**shift)
                cv2.circle(np_image, (px_fix, py_fix), radius=radius_fix, color=(255,255,0), thickness=thickness, shift=shift, lineType=cv2.LINE_AA)

        else:
            cv2.putText(np_image, "Detection failed...", (int(np_image.shape[0]/2.0),int(np_image.shape[1]/5.0)), cv2.FONT_HERSHEY_SIMPLEX, fontScale=1.5, color=(0, 0, 255), thickness=2)

        cv2.imshow(self.windowName, np_image)
        cv2.waitKey(1)


# Script entry point: parse the command line, load the target and camera chain
# configuration, create the validator and hand control to the ROS message loop.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Validate the intrinsics of a camera.')
    parser.add_argument('--target', dest='targetYaml', help='Calibration target configuration as yaml file', required=True)
    parser.add_argument('--cam', dest='chainYaml', help='Camera configuration as yaml file', required=True)
    parser.add_argument('--verbose', action='store_true', dest='verbose', help='Verbose output')
    parsed = parser.parse_args()

    if parsed.verbose:
        sm.setLoggingLevel(sm.LoggingLevel.Debug)
    else:
        sm.setLoggingLevel(sm.LoggingLevel.Info)

    # init the node
    rospy.init_node('kalibr_camera_validator', anonymous=True)

    #load the configuration yamls
    targetConfig = kc.ConfigReader.CalibrationTargetParameters(parsed.targetYaml)

    #the placeholder is required, otherwise format() drops the target type
    print("Initializing calibration target: {0}".format(targetConfig.getTargetType()))
    targetConfig.printDetails()
    camchain = kc.ConfigReader.CameraChainParameters(parsed.chainYaml)

    #create the validator
    chain_validator = CameraChainValidator(camchain, targetConfig)

    #ros message loops
    try:
        rospy.spin()
    except KeyboardInterrupt:
        print("Shutting down")
    finally:
        #close the OpenCV windows on every exit path
        cv2.destroyAllWindows()
